{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '3'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "dimension = 400\n",
    "vocab = \"EOS abcdefghijklmnopqrstuvwxyz'\"\n",
    "char2idx = {char: idx for idx, char in enumerate(vocab)}\n",
    "idx2char = {idx: char for idx, char in enumerate(vocab)}\n",
    "\n",
    "def text2idx(text):\n",
    "    text = re.sub(r'[^a-z ]', '', text.lower()).strip()\n",
    "    converted = [char2idx[char] for char in text]\n",
    "    return text, converted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = 1\n",
    "PAD = 0\n",
    "EOS = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "train_X, train_Y = [], []\n",
    "text_files = [f for f in os.listdir('spectrogram-train') if f.endswith('.npy')]\n",
    "for fpath in text_files:\n",
    "    try:\n",
    "        splitted = fpath.split('-')\n",
    "        if len(splitted) == 2:\n",
    "            splitted[1] = splitted[1].split('.')[1]\n",
    "            fpath = splitted[0] + '.' + splitted[1]\n",
    "        with open('data/' + fpath.replace('npy', 'txt')) as fopen:\n",
    "            text, converted = text2idx(fopen.read())\n",
    "        w = np.load('spectrogram-train/' + fpath)\n",
    "        if w.shape[1] != dimension:\n",
    "            continue\n",
    "        train_X.append(w)\n",
    "        train_Y.append(converted)\n",
    "    except:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_X, test_Y = [], []\n",
    "text_files = [f for f in os.listdir('spectrogram-test') if f.endswith('.npy')]\n",
    "for fpath in text_files:\n",
    "    with open('data/' + fpath.replace('npy', 'txt')) as fopen:\n",
    "        text, converted = text2idx(fopen.read())\n",
    "    w = np.load('spectrogram-test/' + fpath)\n",
    "    if w.shape[1] != dimension:\n",
    "        continue\n",
    "    test_X.append(w)\n",
    "    test_Y.append(converted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_layers,\n",
    "        size_layer,\n",
    "        learning_rate,\n",
    "        num_features,\n",
    "        dropout = 1.0,\n",
    "        beam_width=5, force_teaching_ratio=0.5\n",
    "    ):\n",
    "        \n",
    "        def lstm_cell(size, reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(size, initializer=tf.orthogonal_initializer(),reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.float32, [None, None, num_features])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.X_seq_len = tf.reduce_mean(self.X_seq_len, axis = 1)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        \n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([len(char2idx), size_layer], -1, 1))\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        \n",
    "        self.encoder_out = self.X\n",
    "        print(self.X_seq_len)\n",
    "\n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = lstm_cell(size_layer // 2),\n",
    "                cell_bw = lstm_cell(size_layer // 2),\n",
    "                inputs = self.encoder_out,\n",
    "                sequence_length = self.X_seq_len,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_%d'%(n))\n",
    "            self.encoder_out = tf.concat((out_fw, out_bw), 2)\n",
    "            \n",
    "        bi_state_c = tf.concat((state_fw.c, state_bw.c), -1)\n",
    "        bi_state_h = tf.concat((state_fw.h, state_bw.h), -1)\n",
    "        bi_lstm_state = tf.nn.rnn_cell.LSTMStateTuple(c=bi_state_c, h=bi_state_h)\n",
    "        encoder_state = tuple([bi_lstm_state] * num_layers)\n",
    "        \n",
    "        print(self.encoder_out, encoder_state)\n",
    "        \n",
    "        with tf.variable_scope('decode'):\n",
    "            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(\n",
    "            num_units = size_layer, \n",
    "            memory = self.encoder_out,\n",
    "            memory_sequence_length = self.X_seq_len)\n",
    "            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n",
    "                cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(size_layer) for _ in range(num_layers)]),\n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "            training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(\n",
    "            inputs = tf.nn.embedding_lookup(decoder_embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                embedding = decoder_embeddings,\n",
    "                sampling_probability = 1 - force_teaching_ratio,\n",
    "                time_major = False)\n",
    "            training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cell,\n",
    "                helper = training_helper,\n",
    "                initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state),\n",
    "                output_layer = tf.layers.Dense(len(char2idx)))\n",
    "            training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "            self.training_logits = training_decoder_output.rnn_output\n",
    "            \n",
    "        \n",
    "        with tf.variable_scope('decode', reuse=True):\n",
    "            encoder_out_tiled = tf.contrib.seq2seq.tile_batch(self.encoder_out, beam_width)\n",
    "            encoder_state_tiled = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)\n",
    "            X_seq_len_tiled = tf.contrib.seq2seq.tile_batch(self.X_seq_len, beam_width)\n",
    "            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(\n",
    "                num_units = size_layer, \n",
    "                memory = encoder_out_tiled,\n",
    "                memory_sequence_length = X_seq_len_tiled)\n",
    "            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n",
    "                cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(size_layer, reuse=True) for _ in range(num_layers)]),\n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "            predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n",
    "                cell = decoder_cell,\n",
    "                embedding = decoder_embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS,\n",
    "                initial_state = decoder_cell.zero_state(batch_size * beam_width, tf.float32).clone(cell_state = encoder_state_tiled),\n",
    "                beam_width = beam_width,\n",
    "                output_layer = tf.layers.Dense(len(char2idx), _reuse=True),\n",
    "                length_penalty_weight = 0.0)\n",
    "            predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = False,\n",
    "                maximum_iterations = tf.reduce_max(self.X_seq_len))\n",
    "            self.predicting_ids = predicting_decoder_output.predicted_ids[:, :, 0]\n",
    "        \n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0830 21:27:09.141949 140382336018240 deprecation.py:506] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "W0830 21:27:09.177374 140382336018240 deprecation.py:323] From <ipython-input-6-124b2006545b>:13: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
      "W0830 21:27:09.179373 140382336018240 deprecation.py:323] From <ipython-input-6-124b2006545b>:36: bidirectional_dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API\n",
      "W0830 21:27:09.180603 140382336018240 deprecation.py:323] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:464: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "W0830 21:27:09.294915 140382336018240 deprecation.py:506] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py:961: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"Mean:0\", shape=(?,), dtype=int32)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0830 21:27:09.940955 140382336018240 deprecation.py:323] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:244: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"concat_2:0\", shape=(?, ?, 512), dtype=float32) (LSTMStateTuple(c=<tf.Tensor 'concat_3:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'concat_4:0' shape=(?, 512) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'concat_3:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'concat_4:0' shape=(?, 512) dtype=float32>))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0830 21:27:11.023870 140382336018240 lazy_loader.py:50] \n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "W0830 21:27:11.044019 140382336018240 deprecation.py:506] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0830 21:27:11.318089 140382336018240 deprecation.py:323] From <ipython-input-6-124b2006545b>:52: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
      "W0830 21:27:12.459459 140382336018240 deprecation.py:323] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/helper.py:107: multinomial (from tensorflow.python.ops.random_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.random.categorical` instead.\n",
      "W0830 21:27:13.310252 140382336018240 deprecation.py:323] From /home/husein/.local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py:985: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_util.py:93: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "size_layers = 512\n",
    "learning_rate = 1e-3\n",
    "num_layers = 2\n",
    "batch_size = 64\n",
    "epoch = 20\n",
    "\n",
    "model = Model(num_layers, size_layers, learning_rate, dimension)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "    train_X, dtype = 'float32', padding = 'post'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_X = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "    test_X, dtype = 'float32', padding = 'post'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:54<00:00,  3.86it/s, accuracy=0.878, cost=0.456]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  7.77it/s, accuracy=0.839, cost=0.5]  \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg cost 0.700397, training avg accuracy 0.782703\n",
      "epoch 1, testing avg cost 0.510297, testing avg accuracy 0.834116\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.90it/s, accuracy=0.914, cost=0.198]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.47it/s, accuracy=0.867, cost=0.455]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg cost 0.341112, training avg accuracy 0.886407\n",
      "epoch 2, testing avg cost 0.453416, testing avg accuracy 0.864056\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.90it/s, accuracy=0.971, cost=0.0704]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.61it/s, accuracy=0.854, cost=0.608]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg cost 0.140593, training avg accuracy 0.955247\n",
      "epoch 3, testing avg cost 0.518977, testing avg accuracy 0.866511\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  4.07it/s, accuracy=0.986, cost=0.0522]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:00<00:00,  8.97it/s, accuracy=0.854, cost=0.66] \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg cost 0.063024, training avg accuracy 0.980905\n",
      "epoch 4, testing avg cost 0.558479, testing avg accuracy 0.873320\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:54<00:00,  3.90it/s, accuracy=0.993, cost=0.0232] \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.54it/s, accuracy=0.859, cost=0.838]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg cost 0.033872, training avg accuracy 0.990128\n",
      "epoch 5, testing avg cost 0.626188, testing avg accuracy 0.879525\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.85it/s, accuracy=1, cost=0.00884]    \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.56it/s, accuracy=0.878, cost=0.731]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg cost 0.019692, training avg accuracy 0.994477\n",
      "epoch 6, testing avg cost 0.644497, testing avg accuracy 0.881163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.85it/s, accuracy=0.993, cost=0.012]  \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.65it/s, accuracy=0.874, cost=0.727]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg cost 0.017062, training avg accuracy 0.995187\n",
      "epoch 7, testing avg cost 0.665107, testing avg accuracy 0.880544\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.91it/s, accuracy=1, cost=0.0106]     \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.43it/s, accuracy=0.872, cost=0.711]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg cost 0.016758, training avg accuracy 0.995547\n",
      "epoch 8, testing avg cost 0.683401, testing avg accuracy 0.881203\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.91it/s, accuracy=1, cost=0.000962]   \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.49it/s, accuracy=0.874, cost=0.775]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg cost 0.013883, training avg accuracy 0.996235\n",
      "epoch 9, testing avg cost 0.680090, testing avg accuracy 0.883603\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.88it/s, accuracy=1, cost=0.000373]   \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.51it/s, accuracy=0.872, cost=0.804]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg cost 0.006105, training avg accuracy 0.998362\n",
      "epoch 10, testing avg cost 0.693695, testing avg accuracy 0.884643\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.89it/s, accuracy=1, cost=0.00129]    \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.57it/s, accuracy=0.878, cost=0.798]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg cost 0.005306, training avg accuracy 0.998665\n",
      "epoch 11, testing avg cost 0.717289, testing avg accuracy 0.885421\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:52<00:00,  3.87it/s, accuracy=1, cost=0.00665]    \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.60it/s, accuracy=0.856, cost=0.829]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg cost 0.013640, training avg accuracy 0.996345\n",
      "epoch 12, testing avg cost 0.735101, testing avg accuracy 0.877925\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.76it/s, accuracy=1, cost=0.000336]   \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.54it/s, accuracy=0.891, cost=0.717]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg cost 0.016174, training avg accuracy 0.995488\n",
      "epoch 13, testing avg cost 0.703543, testing avg accuracy 0.886818\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:53<00:00,  3.85it/s, accuracy=0.871, cost=0.348]  \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.07it/s, accuracy=0.841, cost=0.462]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg cost 0.567979, training avg accuracy 0.906560\n",
      "epoch 14, testing avg cost 0.450457, testing avg accuracy 0.851585\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:55<00:00,  3.79it/s, accuracy=0.964, cost=0.105] \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.00it/s, accuracy=0.893, cost=0.41] \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg cost 0.204097, training avg accuracy 0.934336\n",
      "epoch 15, testing avg cost 0.401671, testing avg accuracy 0.891277\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:55<00:00,  3.66it/s, accuracy=1, cost=0.00893]    \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.28it/s, accuracy=0.883, cost=0.565]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg cost 0.029877, training avg accuracy 0.992118\n",
      "epoch 16, testing avg cost 0.495280, testing avg accuracy 0.892635\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:55<00:00,  3.67it/s, accuracy=1, cost=0.000908]   \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.18it/s, accuracy=0.888, cost=0.54] \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg cost 0.007138, training avg accuracy 0.998470\n",
      "epoch 17, testing avg cost 0.509342, testing avg accuracy 0.895123\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:55<00:00,  3.70it/s, accuracy=1, cost=0.000807]   \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.35it/s, accuracy=0.889, cost=0.526]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg cost 0.001952, training avg accuracy 0.999641\n",
      "epoch 18, testing avg cost 0.559849, testing avg accuracy 0.893137\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:55<00:00,  3.82it/s, accuracy=0.993, cost=0.0153] \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  8.11it/s, accuracy=0.871, cost=0.646]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg cost 0.001879, training avg accuracy 0.999579\n",
      "epoch 19, testing avg cost 0.585445, testing avg accuracy 0.890329\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [00:55<00:00,  3.84it/s, accuracy=1, cost=0.00508]    \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:01<00:00,  7.84it/s, accuracy=0.897, cost=0.569]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg cost 0.016320, training avg accuracy 0.995586\n",
      "epoch 20, testing avg cost 0.576143, testing avg accuracy 0.892865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_cost, train_accuracy, test_cost, test_accuracy = [], [], [], []\n",
    "    for i in pbar:\n",
    "        batch_x = train_X[i : min(i + batch_size, len(train_X))]\n",
    "        y = train_Y[i : min(i + batch_size, len(train_X))]\n",
    "        batch_y, _ = pad_sentence_batch(y, 0)\n",
    "        _, cost, accuracy = sess.run(\n",
    "            [model.optimizer, model.cost, model.accuracy],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y},\n",
    "        )\n",
    "        train_cost.append(cost)\n",
    "        train_accuracy.append(accuracy)\n",
    "        pbar.set_postfix(cost = cost, accuracy = accuracy)\n",
    "    \n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'testing minibatch loop')\n",
    "    for i in pbar:\n",
    "        batch_x = test_X[i : min(i + batch_size, len(test_X))]\n",
    "        y = test_Y[i : min(i + batch_size, len(test_X))]\n",
    "        batch_y, _ = pad_sentence_batch(y, 0)\n",
    "        cost, accuracy = sess.run(\n",
    "            [model.cost, model.accuracy],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y},\n",
    "        )\n",
    "        \n",
    "        test_cost.append(cost)\n",
    "        test_accuracy.append(accuracy)\n",
    "        \n",
    "        pbar.set_postfix(cost = cost, accuracy = accuracy)\n",
    "    print('epoch %d, training avg cost %f, training avg accuracy %f'%(e + 1, np.mean(train_cost), \n",
    "                                                                      np.mean(train_accuracy)))\n",
    "    \n",
    "    print('epoch %d, testing avg cost %f, testing avg accuracy %f'%(e + 1, np.mean(test_cost), \n",
    "                                                                    np.mean(test_accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "real: say the word shout\n",
      "predicted: say the word shouthe woutho wo\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "random_index = random.randint(0, len(test_X) - 1)\n",
    "batch_x = test_X[random_index : random_index + 1]\n",
    "print(\n",
    "    'real:',\n",
    "    ''.join(\n",
    "        [idx2char[no] for no in test_Y[random_index : random_index + 1][0]]\n",
    "    ),\n",
    ")\n",
    "pred = sess.run(model.predicting_ids, feed_dict = {model.X: batch_x})[0]\n",
    "print('predicted:', ''.join([idx2char[no] for no in pred]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
